{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "gS1TV7cOYnWM"
   },
   "source": [
    "# Emotion Classification in short texts with BERT\n",
    "\n",
    "This notebook applies BERT to multiclass text classification. The dataset consists of written dialogs, messages and short stories. Each dialog utterance or message is labeled with one of five emotion categories: joy, anger, sadness, fear, neutral.\n",
    "\n",
    "## Workflow: \n",
    "1. Import Data\n",
    "2. Data preprocessing and downloading BERT\n",
    "3. Training and validation\n",
    "4. Saving the model\n",
    "\n",
    "We use BERT through [ktrain](https://github.com/amaiya/ktrain). Google Colab offers a free GPU for training.\n",
    "\n",
    "👋  **Let's start** "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 819
    },
    "colab_type": "code",
    "id": "3ckm3TIdYnxF",
    "outputId": "e466435e-bfae-4312-e17b-63f4c76e15c9"
   },
   "outputs": [],
   "source": [
    "# install ktrain on Google Colab\n",
    "!pip3 install ktrain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 97
    },
    "colab_type": "code",
    "id": "FYJa3hJiYnWP",
    "outputId": "0f7f0dec-cdf8-48a5-fe4c-fb7695f76856"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<p style=\"color: red;\">\n",
       "The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x.<br>\n",
       "We recommend you <a href=\"https://www.tensorflow.org/guide/migrate\" target=\"_blank\">upgrade</a> now \n",
       "or ensure your notebook will continue to use TensorFlow 1.x via the <code>%tensorflow_version 1.x</code> magic:\n",
       "<a href=\"https://colab.research.google.com/notebooks/tensorflow_version.ipynb\" target=\"_blank\">more info</a>.</p>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "using Keras version: 2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import ktrain\n",
    "from ktrain import text"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "33Sa8kVvYnWR"
   },
   "source": [
    "## 1. Import Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 495
    },
    "colab_type": "code",
    "id": "xGRw5awSX7mZ",
    "outputId": "b0b110b6-a3ed-48a3-e46f-5071475dffde"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "size of training set: 7934\n",
      "size of validation set: 3393\n",
      "joy        2326\n",
      "sadness    2317\n",
      "anger      2259\n",
      "neutral    2254\n",
      "fear       2171\n",
      "Name: Emotion, dtype: int64\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Emotion</th>\n",
       "      <th>Text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>neutral</td>\n",
       "      <td>There are tons of other paintings that I thin...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>sadness</td>\n",
       "      <td>Yet the dog had grown old and less capable , a...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>fear</td>\n",
       "      <td>When I get into the tube or the train without ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>fear</td>\n",
       "      <td>This last may be a source of considerable disq...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>anger</td>\n",
       "      <td>She disliked the intimacy he showed towards so...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>sadness</td>\n",
       "      <td>When my family heard that my Mother's cousin w...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>joy</td>\n",
       "      <td>Finding out I am chosen to collect norms for C...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>anger</td>\n",
       "      <td>A spokesperson said : ` Glen is furious that t...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>neutral</td>\n",
       "      <td>Yes .</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>sadness</td>\n",
       "      <td>When I see people with burns I feel sad, actua...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Emotion                                               Text\n",
       "0  neutral   There are tons of other paintings that I thin...\n",
       "1  sadness  Yet the dog had grown old and less capable , a...\n",
       "2     fear  When I get into the tube or the train without ...\n",
       "3     fear  This last may be a source of considerable disq...\n",
       "4    anger  She disliked the intimacy he showed towards so...\n",
       "5  sadness  When my family heard that my Mother's cousin w...\n",
       "6      joy  Finding out I am chosen to collect norms for C...\n",
       "7    anger  A spokesperson said : ` Glen is furious that t...\n",
       "8  neutral                                             Yes . \n",
       "9  sadness  When I see people with burns I feel sad, actua..."
      ]
     },
     "execution_count": 4,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_train = pd.read_csv('data/data_train.csv', encoding='utf-8')\n",
    "data_test = pd.read_csv('data/data_test.csv', encoding='utf-8')\n",
    "\n",
    "X_train = data_train.Text.tolist()\n",
    "X_test = data_test.Text.tolist()\n",
    "\n",
    "y_train = data_train.Emotion.tolist()\n",
    "y_test = data_test.Emotion.tolist()\n",
    "\n",
    "# DataFrame.append was removed in pandas 2.0; pd.concat gives the same combined frame\n",
    "data = pd.concat([data_train, data_test], ignore_index=True)\n",
    "\n",
    "class_names = ['joy', 'sadness', 'fear', 'anger', 'neutral']\n",
    "\n",
    "print('size of training set: %s' % (len(data_train['Text'])))\n",
    "print('size of validation set: %s' % (len(data_test['Text'])))\n",
    "print(data.Emotion.value_counts())\n",
    "\n",
    "data.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gBxd1kYYYnWU"
   },
   "outputs": [],
   "source": [
    "encoding = {\n",
    "    'joy': 0,\n",
    "    'sadness': 1,\n",
    "    'fear': 2,\n",
    "    'anger': 3,\n",
    "    'neutral': 4\n",
    "}\n",
    "\n",
    "# Integer values for each class\n",
    "y_train = [encoding[x] for x in y_train]\n",
    "y_test = [encoding[x] for x in y_test]"
   ]
  },
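  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The integer codes in `encoding` follow the order of `class_names`, which matters because the same list is later used to turn predicted indices back into labels. A minimal sketch of that round trip, assuming a small made-up label list (`sample_labels` is hypothetical and not taken from the dataset):\n",
    "\n",
    "```python\n",
    "# Class order copied from the notebook; sample_labels is an illustrative assumption.\n",
    "class_names = ['joy', 'sadness', 'fear', 'anger', 'neutral']\n",
    "sample_labels = ['fear', 'joy', 'neutral', 'anger']\n",
    "\n",
    "# Derive the mapping from the list so the two cannot drift apart.\n",
    "encoding = {name: index for index, name in enumerate(class_names)}\n",
    "\n",
    "encoded = [encoding[label] for label in sample_labels]\n",
    "decoded = [class_names[index] for index in encoded]\n",
    "\n",
    "print(encoded)\n",
    "print(decoded == sample_labels)\n",
    "```\n",
    "\n",
    "Building the dictionary from `class_names` instead of typing it by hand is a design choice, not something the original notebook does."
   ]
  },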
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "pcuN5eklYnWW"
   },
   "source": [
    "## 2. Data preprocessing\n",
    "\n",
    "* The text must be preprocessed in a specific way for use with BERT. This is done by setting `preprocess_mode` to `'bert'`. The BERT model and vocabulary are downloaded automatically.\n",
    "\n",
    "* BERT can handle a maximum sequence length of 512 tokens, but we use a smaller `maxlen` to reduce memory use and speed up training. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 255
    },
    "colab_type": "code",
    "id": "IQnZnKQmZA3U",
    "outputId": "38dcedee-706f-4d32-e578-976b8b55826c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "downloading pretrained BERT model (uncased_L-12_H-768_A-12.zip)...\n",
      "[██████████████████████████████████████████████████]\n",
      "extracting pretrained BERT model...\n",
      "done.\n",
      "\n",
      "cleanup downloaded zip...\n",
      "done.\n",
      "\n",
      "preprocessing train...\n",
      "language: en\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "done."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "preprocessing test...\n",
      "language: en\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "done."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "(x_train,  y_train), (x_test, y_test), preproc = text.texts_from_array(x_train=X_train, y_train=y_train,\n",
    "                                                                       x_test=X_test, y_test=y_test,\n",
    "                                                                       class_names=class_names,\n",
    "                                                                       preprocess_mode='bert',\n",
    "                                                                       maxlen=350, \n",
    "                                                                       max_features=35000)"
   ]
  },
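  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "One way to choose `maxlen` is to check how long the texts actually are. The sketch below is only an illustration: `sample_texts` is made up, `length_at_percentile` is a hypothetical helper, and whitespace splitting merely approximates BERT's WordPiece tokenization, which usually produces more tokens than words.\n",
    "\n",
    "```python\n",
    "# Made-up short texts standing in for the Text column.\n",
    "sample_texts = [\n",
    "    'Yes .',\n",
    "    'I am so happy today !',\n",
    "    'The noise outside kept me awake all night',\n",
    "    'He slammed the door and left without saying a single word',\n",
    "]\n",
    "\n",
    "def length_at_percentile(lengths, percentile):\n",
    "    # Nearest-rank percentile: smallest length that covers `percentile` % of the texts.\n",
    "    ordered = sorted(lengths)\n",
    "    rank = max(1, -(-len(ordered) * percentile // 100))  # ceiling division\n",
    "    return ordered[rank - 1]\n",
    "\n",
    "word_counts = [len(sample.split()) for sample in sample_texts]\n",
    "print(word_counts)\n",
    "print(length_at_percentile(word_counts, 95))\n",
    "```\n",
    "\n",
    "On the real data, a `maxlen` somewhat above a high percentile of the length distribution would truncate few texts while keeping memory use down."
   ]
  },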
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "V4fuvHPnYnWY"
   },
   "source": [
    "## 3. Training and validation\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "1KwBXM9BYnWZ"
   },
   "source": [
    "Load the pretrained BERT model for text classification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "colab_type": "code",
    "id": "1cGb2CaOZBNS",
    "outputId": "4d05d637-9692-45e3-8ada-0dab9603480e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Is Multi-Label? False\n",
      "maxlen is 350\n",
      "done.\n"
     ]
    }
   ],
   "source": [
    "model = text.text_classifier('bert', train_data=(x_train, y_train), preproc=preproc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JkgvS2pvYnWb"
   },
   "source": [
    "Wrap the model and data in a ktrain `Learner` object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "zu8uv-xKYnWc"
   },
   "outputs": [],
   "source": [
    "learner = ktrain.get_learner(model, train_data=(x_train, y_train), \n",
    "                             val_data=(x_test, y_test),\n",
    "                             batch_size=6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "h4YY0JBAYnWd"
   },
   "source": [
    "Train the model. For more on tuning learning rates, see [this ktrain tutorial](https://github.com/amaiya/ktrain/blob/master/tutorial-02-tuning-learning-rates.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 224
    },
    "colab_type": "code",
    "id": "jD-2RpgkZN_n",
    "outputId": "74ab6af9-e127-44b2-bb14-eb23b3ed17d3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "begin training using onecycle policy with max lr of 2e-05...\n",
      "Train on 7934 samples, validate on 3393 samples\n",
      "Epoch 1/3\n",
      "7934/7934 [==============================] - 475s 60ms/sample - loss: 0.9311 - acc: 0.6364 - val_loss: 0.5669 - val_acc: 0.8034\n",
      "Epoch 2/3\n",
      "7934/7934 [==============================] - 466s 59ms/sample - loss: 0.4569 - acc: 0.8470 - val_loss: 0.5211 - val_acc: 0.8232\n",
      "Epoch 3/3\n",
      "7934/7934 [==============================] - 466s 59ms/sample - loss: 0.1911 - acc: 0.9411 - val_loss: 0.5589 - val_acc: 0.8320\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7ffa776ace10>"
      ]
     },
     "execution_count": 9,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learner.fit_onecycle(2e-5, 3)"
   ]
  },
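  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "`fit_onecycle(2e-5, 3)` trains for 3 epochs with a one-cycle policy whose maximum learning rate is 2e-5, as the training log shows. The sketch below illustrates the general shape of such a schedule, a linear ramp up followed by a linear ramp down. It is not ktrain's implementation: the function name, the starting rate of `max_lr / 10`, and the two equal halves are all assumptions.\n",
    "\n",
    "```python\n",
    "def one_cycle_lr(step, total_steps, max_lr, start_div=10.0):\n",
    "    # Triangular schedule: rise linearly to max_lr at the midpoint, then fall back.\n",
    "    min_lr = max_lr / start_div\n",
    "    half = total_steps / 2\n",
    "    if step <= half:\n",
    "        fraction = step / half\n",
    "    else:\n",
    "        fraction = (total_steps - step) / half\n",
    "    return min_lr + fraction * (max_lr - min_lr)\n",
    "\n",
    "total_steps = 100\n",
    "for step in (0, 25, 50, 75, 100):\n",
    "    print(step, one_cycle_lr(step, total_steps, max_lr=2e-5))\n",
    "```\n",
    "\n",
    "The rate peaks at the midpoint and returns to its starting value at the end of training."
   ]
  },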
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "xHhddIieYnWg"
   },
   "source": [
    "Validate the model on the held-out test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 306
    },
    "colab_type": "code",
    "id": "2s4ao_e2i4ld",
    "outputId": "fd6d79ca-5d3d-41b0-86ff-a8d331e2847d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "         joy       0.87      0.85      0.86       707\n",
      "     sadness       0.84      0.79      0.82       676\n",
      "        fear       0.86      0.87      0.86       679\n",
      "       anger       0.81      0.80      0.81       693\n",
      "     neutral       0.78      0.85      0.81       638\n",
      "\n",
      "    accuracy                           0.83      3393\n",
      "   macro avg       0.83      0.83      0.83      3393\n",
      "weighted avg       0.83      0.83      0.83      3393\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[598,   8,  15,  13,  73],\n",
       "       [ 18, 537,  37,  54,  30],\n",
       "       [ 16,  20, 590,  40,  13],\n",
       "       [ 19,  49,  35, 557,  33],\n",
       "       [ 37,  24,  12,  24, 541]])"
      ]
     },
     "execution_count": 10,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learner.validate(val_data=(x_test, y_test), class_names=class_names)"
   ]
  },
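  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The figures in the classification report can be recomputed from the confusion matrix printed above. Its row sums equal the `support` column, so rows are true classes and columns are predicted classes. A short sketch in plain Python (the variable names are ours):\n",
    "\n",
    "```python\n",
    "class_names = ['joy', 'sadness', 'fear', 'anger', 'neutral']\n",
    "\n",
    "# Confusion matrix copied from the validation output: rows = true, columns = predicted.\n",
    "confusion = [\n",
    "    [598, 8, 15, 13, 73],\n",
    "    [18, 537, 37, 54, 30],\n",
    "    [16, 20, 590, 40, 13],\n",
    "    [19, 49, 35, 557, 33],\n",
    "    [37, 24, 12, 24, 541],\n",
    "]\n",
    "\n",
    "total = sum(sum(row) for row in confusion)\n",
    "correct = sum(confusion[i][i] for i in range(len(confusion)))\n",
    "print('accuracy: %.2f' % (correct / total))\n",
    "\n",
    "for i, name in enumerate(class_names):\n",
    "    predicted_as_i = sum(row[i] for row in confusion)  # column sum\n",
    "    actually_i = sum(confusion[i])                     # row sum\n",
    "    precision = confusion[i][i] / predicted_as_i\n",
    "    recall = confusion[i][i] / actually_i\n",
    "    print('%-8s precision %.2f  recall %.2f' % (name, precision, recall))\n",
    "```\n",
    "\n",
    "The first row also shows the most common confusion: 73 `joy` texts were predicted as `neutral`."
   ]
  },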
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MI0nWZlhYnWi"
   },
   "source": [
    "#### Testing with other inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "6eDpFIoXjlE8",
    "outputId": "a59d84b8-ec69-4253-bfa3-7364bb791fb1"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['joy', 'sadness', 'fear', 'anger', 'neutral']"
      ]
     },
     "execution_count": 11,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictor = ktrain.get_predictor(learner.model, preproc)\n",
    "predictor.get_classes()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "w5KQfwX8xLLu",
    "outputId": "99d2ba47-87ba-4e96-c03f-558746e91ede"
   },
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predicted: sadness (0.06)\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "message = 'I just broke up with my boyfriend'\n",
    "\n",
    "start_time = time.time()\n",
    "prediction = predictor.predict(message)\n",
    "\n",
    "# the value in parentheses is the elapsed prediction time in seconds, not a probability\n",
    "print('predicted: {} ({:.2f})'.format(prediction, (time.time() - start_time)))"
   ]
  },
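  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "`predictor.predict(message)` returns only the top label. To see how confident the model is, ktrain text predictors also accept `return_proba=True`, which returns one probability per class instead of a label. The sketch below pairs such a vector with the class names. The numbers in `probabilities` are invented for illustration and are not model output.\n",
    "\n",
    "```python\n",
    "class_names = ['joy', 'sadness', 'fear', 'anger', 'neutral']\n",
    "\n",
    "# Invented probabilities, one per class, in the same order as class_names.\n",
    "probabilities = [0.02, 0.81, 0.06, 0.08, 0.03]\n",
    "\n",
    "# Sort the classes from most to least likely.\n",
    "ranked = sorted(zip(class_names, probabilities), key=lambda pair: pair[1], reverse=True)\n",
    "\n",
    "for name, probability in ranked:\n",
    "    print('%-8s %.2f' % (name, probability))\n",
    "\n",
    "top_label = ranked[0][0]\n",
    "```\n",
    "\n",
    "Looking at the runner-up classes is useful for short or ambiguous messages, where the top label alone can hide a close call."
   ]
  },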
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rusM_SzpYnWm"
   },
   "source": [
    "## 4. Saving the BERT model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "YBkGko4Sz2ef"
   },
   "outputs": [],
   "source": [
    "# let's save the predictor for later use\n",
    "predictor.save(\"models/bert_model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "qQo2P03tYnWo"
   },
   "source": [
    "Done! To reload the predictor, use `ktrain.load_predictor`."
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "bert.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
